--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements.  See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License.  You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--

--- Collection of util functions.
--
-- @module core.utils

local config_local   = require("apisix.core.config_local")
local core_str       = require("apisix.core.string")
local rfind_char     = core_str.rfind_char
local table          = require("apisix.core.table")
local log            = require("apisix.core.log")
local string         = require("apisix.core.string")
local dns_client     = require("apisix.core.dns.client")
local ngx_re         = require("ngx.re")
local ipmatcher      = require("resty.ipmatcher")
local ffi            = require("ffi")
local base           = require("resty.core.base")
local open           = io.open
local sub_str        = string.sub
local str_byte       = string.byte
local tonumber       = tonumber
local tostring       = tostring
local re_gsub        = ngx.re.gsub
local re_match       = ngx.re.match
local re_gmatch      = ngx.re.gmatch
local type           = type
local io_popen       = io.popen
local C              = ffi.C
local ffi_string     = ffi.string
local get_string_buf = base.get_string_buf
local exiting        = ngx.worker.exiting
local ngx_sleep      = ngx.sleep
local ipairs         = ipairs

-- Cached result of gethostname(); filled on first call.
local hostname
-- Nameserver list stored by set_resolver(); consumed by dns_parse().
local dns_resolvers
-- The dns_resolvers table the current DNS client was built from. dns_parse()
-- rebuilds the client when dns_resolvers no longer equals it (by reference).
local current_inited_resolvers
-- DNS client instance created lazily by dns_parse().
local current_dns_client
-- Longest single ngx.sleep() taken by sleep() before re-checking whether the
-- worker is exiting.
local max_sleep_interval = 1

-- Declare nginx's URI escaping routine so it can be called through the
-- LuaJIT FFI (used by uri_safe_encode, which passes 0 as `type`).
ffi.cdef[[
    int ngx_escape_uri(char *dst, const char *src,
        size_t size, int type);
]]


-- Module table; parse_ipv4 / parse_ipv6 are re-exported from resty.ipmatcher.
local _M = {
    version = 0.2,
    parse_ipv4 = ipmatcher.parse_ipv4,
    parse_ipv6 = ipmatcher.parse_ipv6,
}


--- Read 8 bytes from /dev/urandom and fold them, most significant byte
-- first, into a single number usable as a random seed.
--
-- @return the seed, or nil and an error message when /dev/urandom cannot
--         be opened or read
-- NOTE(review): numbers are presumably doubles here (LuaJIT), so the 64-bit
-- value is not represented exactly; that is harmless for seeding.
function _M.get_seed_from_urandom()
    local urandom, open_err = open("/dev/urandom", "rb")
    if urandom == nil then
        return nil, 'failed to open /dev/urandom: ' .. open_err
    end

    local raw = urandom:read(8)
    urandom:close()
    if raw == nil then
        return nil, 'failed to read data from /dev/urandom'
    end

    local seed = 0
    for idx = 1, 8 do
        seed = seed * 256 + raw:byte(idx)
    end
    return seed
end


--- Split a URI on "/" with ngx.re.split.
-- Whatever ngx.re.split returns (parts, or nil and an error) is passed
-- straight through.
function _M.split_uri(uri)
    local separator = "/"
    return ngx_re.split(uri, separator)
end


--- Resolve `domain` through a DNS client built from the nameservers stored
-- by set_resolver().
--
-- The client is created lazily. It is rebuilt whenever the stored resolver
-- table is replaced (compared by reference) or no client exists yet.
--
-- @param domain   the name to resolve
-- @param selector passed through unchanged to the client's resolve()
-- @return the results of the client's resolve(), or nil and an error
--         message when no resolvers are set or the client cannot be created
local function dns_parse(domain, selector)
    if not current_dns_client or dns_resolvers ~= current_inited_resolvers then
        -- Without resolvers there is nothing to build a client from; report
        -- it instead of later calling resolve() on a nil client.
        if not dns_resolvers then
            return nil, "dns resolvers are not set"
        end

        local local_conf = config_local.local_conf()
        local valid = table.try_read_attr(local_conf, "apisix", "dns_resolver_valid")
        local enable_resolv_search_opt = table.try_read_attr(local_conf, "apisix",
                                                             "enable_resolv_search_opt")
        local opts = {
            -- copy so later mutation of the caller's table cannot leak in
            nameservers = table.clone(dns_resolvers),
            order = {"last", "A", "AAAA", "CNAME"}, -- avoid querying SRV
            validTtl = valid,
        }

        if not enable_resolv_search_opt then
            -- empty search list; presumably disables search-domain expansion
            opts.search = {}
        end

        local client, err = dns_client.new(opts)
        if not client then
            -- tostring: a nil error must not break the concatenation
            return nil, "failed to init the dns client: " .. tostring(err)
        end

        current_dns_client = client
        current_inited_resolvers = dns_resolvers
    end

    return current_dns_client:resolve(domain, selector)
end
_M.dns_parse = dns_parse


--- Store the nameserver list used by dns_parse().
-- dns_parse() compares this value by reference, so installing a different
-- table makes the next lookup rebuild the DNS client.
local set_resolver = function(resolvers)
    dns_resolvers = resolvers
end
_M.set_resolver = set_resolver


--- Return the nameserver list most recently stored by set_resolver()
-- (nil if none has been stored yet).
-- The previous signature declared an unused `resolvers` parameter; callers
-- that still pass an argument keep working because Lua ignores extra args.
function _M.get_resolver()
    return dns_resolvers
end


-- Split "host:port" at a ':' located via rfind_char starting from the
-- second-to-last byte (presumably scanning right-to-left), so a trailing
-- ':' is never treated as the separator.
-- Returns host and tonumber(port). The port is nil when it is not numeric,
-- and the whole input is returned as host when no separator is found.
local function _parse_ipv4_or_host(addr)
    local colon_at = rfind_char(addr, ":", #addr - 1)
    if colon_at == nil then
        return addr, nil
    end

    return sub_str(addr, 1, colon_at - 1), tonumber(sub_str(addr, colon_at + 1))
end


-- An unbracketed IPv6 literal cannot carry a port, so the input is returned
-- unchanged as the host (a single return value).
local function _parse_ipv6_without_port(ip)
    return ip
end


-- parse_addr splits 'addr' into its host and port parts. When 'addr' has no
-- port, nil is returned as the port.
-- For a bracketed IPv6 literal such as [::1], the brackets are kept in the
-- returned host.
-- The input is not validated: for malformed 'addr' the returned values are
-- unspecified.
function _M.parse_addr(addr)
    local right_bracket = str_byte("]")

    if str_byte(addr, 1) ~= str_byte("[") then
        -- Unbracketed input. It may be IPv4 (with or without port), a bare
        -- IPv6 such as "2001:db8::68" or "::ffff:192.0.2.1", a host name
        -- such as "test.com" or "localhost" (with or without port), or
        -- something malformed.
        -- Scan left to right: the first '.' means IPv4/host, while a second
        -- ':' seen before any '.' means a bare IPv6 without port.
        local dot, colon = str_byte("."), str_byte(":")
        local colons_seen = 0
        for idx = 1, #addr do
            local b = str_byte(addr, idx, idx)
            if b == dot then
                return _parse_ipv4_or_host(addr)
            end
            if b == colon then
                colons_seen = colons_seen + 1
                if colons_seen == 2 then
                    return _parse_ipv6_without_port(addr)
                end
            end
        end

        return _parse_ipv4_or_host(addr)
    end

    -- Bracketed IPv6 without port: "[ip:v6]"
    if str_byte(addr, #addr) == right_bracket then
        return addr, nil
    end

    -- Bracketed IPv6 with port: "[ip:v6]:port"
    local sep = rfind_char(addr, ":", #addr - 1)
    if sep == nil or str_byte(addr, sep - 1) ~= right_bracket then
        -- malformed addr
        return addr, nil
    end

    return sub_str(addr, 1, sep - 1), tonumber(sub_str(addr, sep + 1))
end


--- Percent-escape `uri` with nginx's ngx_escape_uri (mode 0).
-- @tparam string uri
-- @treturn string the escaped URI
function _M.uri_safe_encode(uri)
    local src_len = #uri
    -- A first pass with a nil destination yields the number of bytes that
    -- need escaping; each of them grows the output by two bytes.
    local n_escape = C.ngx_escape_uri(nil, uri, src_len, 0)
    local dst_len = src_len + n_escape * 2

    local dst = get_string_buf(dst_len)
    C.ngx_escape_uri(dst, uri, src_len, 0)
    -- ffi.string copies the bytes out of the scratch buffer
    return ffi_string(dst, dst_len)
end


--- Check that a header name contains only visible ASCII characters
-- ('!' through '~', bytes 33-126) and no ':' (byte 58).
-- An empty string is accepted.
-- @treturn boolean
function _M.validate_header_field(field)
    local colon = 58
    for pos = 1, #field do
        local b = str_byte(field, pos, pos)
        if b <= 32 or b >= 127 or b == colon then
            return false
        end
    end
    return true
end


--- Check that a header value contains only printable ASCII (bytes 32-126).
-- Values that are not strings are not inspected and are accepted.
-- NOTE(review): this also rejects HTAB (0x09) and every byte >= 0x80 (for
-- example UTF-8 text); confirm that is intended.
-- @treturn boolean
function _M.validate_header_value(value)
    if type(value) == "string" then
        for pos = 1, #value do
            local b = str_byte(value, pos, pos)
            -- control characters and anything outside 7-bit printable ASCII
            if not (b >= 32 and b < 127) then
                return false
            end
        end
    end
    return true
end


---
-- Returns the standard host name of the local host.
-- Only use this method in init/init_worker phase (it runs /bin/hostname
-- through io.popen). The result is cached after the first call. When the
-- command cannot be started or read, "unknown" is cached and returned.
--
-- @function core.utils.gethostname
-- @treturn string The host name of the local host.
-- @usage
-- local hostname = core.utils.gethostname() -- "localhost"
function _M.gethostname()
    if hostname then
        return hostname
    end

    local hd, err = io_popen("/bin/hostname")
    if not hd then
        hostname = "unknown"
        log.error("failed to run \"/bin/hostname\": ", err)
        return hostname
    end

    local data
    data, err = hd:read("*a")
    -- release the pipe (and the child process) whether or not read succeeded
    hd:close()

    if data then
        hostname = data
        -- strip a single trailing line terminator
        if string.has_suffix(hostname, "\r\n") then
            hostname = sub_str(hostname, 1, -3)
        elseif string.has_suffix(hostname, "\n") then
            hostname = sub_str(hostname, 1, -2)
        end

    else
        hostname = "unknown"
        log.error("failed to read output of \"/bin/hostname\": ", err)
    end

    return hostname
end


--- Sleep for `sec` seconds in chunks of at most max_sleep_interval seconds.
-- After each full chunk the worker's exiting() flag is checked. When it is
-- set, the function returns early without sleeping the remainder.
local function sleep(sec)
    local remaining = sec
    -- keep taking full chunks until the remainder fits into one
    while not (remaining <= max_sleep_interval) do
        ngx_sleep(max_sleep_interval)
        if exiting() then
            return
        end
        remaining = remaining - max_sleep_interval
    end
    return ngx_sleep(remaining)
end


_M.sleep = sleep


local resolve_var
do
    -- Per-call state kept in upvalues so the gsub callback below is created
    -- once instead of as a new closure per call. As a consequence the
    -- function is not reentrant.
    local cur_ctx
    local cur_escaper
    local resolved_count
    local var_pattern = [[(?<!\\)\$(\{(\w+)\}|(\w+))]]

    -- gsub callback: m[2] holds the name in "${name}", m[3] in "$name".
    local function substitute(m)
        local name = m[2] or m[3]
        local val = cur_ctx[name]
        if val == nil then
            -- unknown variables are replaced with an empty string
            return ""
        end

        resolved_count = resolved_count + 1
        local text = tostring(val)
        if cur_escaper then
            return cur_escaper(text)
        end
        return text
    end

    -- Replace "$name" / "${name}" (name = \w+) in `tpl` with tostring(ctx[name]),
    -- optionally passed through `escaper`. A '$' preceded by '\' is left
    -- alone, and the backslash stays in the output.
    -- Returns: result, nil, number of variables that resolved to non-nil;
    --          or nil, err if the regex substitution fails.
    function resolve_var(tpl, ctx, escaper)
        resolved_count = 0
        if not tpl or not core_str.find(tpl, "$") then
            return tpl, nil, resolved_count
        end

        cur_ctx, cur_escaper = ctx, escaper
        local res, _, err = re_gsub(tpl, var_pattern, substitute, "jo")
        cur_ctx, cur_escaper = nil, nil

        if not res then
            return nil, err
        end
        return res, nil, resolved_count
    end
end
-- Resolve ngx.var-style variables in the given string
_M.resolve_var = resolve_var


local resolve_var_with_captures
do
    -- Captures of the current call, held in an upvalue so the gsub callback
    -- is created only once (not reentrant).
    local cur_captures
    -- "x" (extended) mode: whitespace inside the pattern is ignored.
    -- Escaping is only partly supported: for "\$1" the '\' stays in the output.
    local capture_pat = [[ (?<! \\) \$ \{? (\d+) \}? ]]

    -- gsub callback: missing (nil) or unmatched (false) captures become "".
    local function substitute(m)
        return cur_captures[tonumber(m[1])] or ""
    end

    -- Replace "$N" / "${N}" in `tpl` with captures[N]. `captures` is the
    -- match result of the regex uri in the proxy-rewrite plugin.
    -- Returns: result, nil; or nil, err if the regex substitution fails.
    function resolve_var_with_captures(tpl, captures)
        if not tpl or not core_str.find(tpl, "$") then
            return tpl, nil
        end

        cur_captures = captures or {}
        local res, _, err = re_gsub(tpl, capture_pat, substitute, "jox")
        cur_captures = nil

        if res then
            return res, nil
        end
        return nil, err
    end
end
-- Resolve $1, ${2}, ... in the given string
_M.resolve_var_with_captures = resolve_var_with_captures


-- Resolve a dotted path such as `some_plugin.some_field.nested_field`
-- against `conf`. Returns the table that directly holds the last segment,
-- together with that segment's name. For "a.b.c" that is conf.a.b, "c".
-- A path without periods yields `conf` and `str` itself.
-- Returns nil when `conf` or any intermediate segment is not a table, so
-- callers can skip missing paths instead of raising "attempt to index nil".
-- `field` is returned alongside `conf` only if `str` cannot be scanned.
local function get_root_conf(str, conf, field)
    local it = re_gmatch(str, [[([^\.]+)]], "jo")
    if not it then
        return conf, field
    end

    local segments = {}
    while true do
        local m = it()
        if not m then
            break
        end
        segments[#segments + 1] = m[0]
    end

    -- walk down to the table that holds the last segment
    local n = #segments
    for i = 1, n - 1 do
        if type(conf) ~= "table" then
            return nil
        end
        conf = conf[segments[i]]
    end

    if type(conf) ~= "table" then
        return nil
    end

    return conf, segments[n]
end


-- Warn when `value` does not start with "https" (case-sensitive match).
-- If the regex itself errors, nothing is logged.
-- The pattern is constant, so "jo" lets ngx.re cache the compiled regex
-- instead of recompiling it on every call.
local function find_and_log(field, plugin_name, value)
    local match, err = re_match(value, "^https", "jo")
    if not match and not err then
        log.warn("Using ", plugin_name, " " , field, " with no TLS is a security risk")
    end
end


--- Log a warning for each configured URL that does not use https.
-- @tparam table fields array of (possibly dotted) paths into `conf`
-- @tparam table conf the plugin configuration
-- @tparam string plugin_name used in the warning text
-- A field may hold a single value or an array of values. Absent fields are
-- skipped and the remaining fields are still checked.
function _M.check_https(fields, conf, plugin_name)
    for _, field in ipairs(fields) do
        local new_conf, new_field = get_root_conf(field, conf)
        local value = new_conf and new_conf[new_field]

        if value then
            if type(value) == "table" then
                for _, v in ipairs(value) do
                    find_and_log(field, plugin_name, v)
                end
            else
                find_and_log(field, plugin_name, value)
            end
        end
    end
end


--- Log a warning for each TLS-related option that is explicitly set to
-- something other than `true`. Unset (nil) options are not reported.
-- @tparam table fields array of (possibly dotted) paths into `conf`
-- @tparam table conf the plugin configuration
-- @tparam string plugin_name used in the warning text
-- Fields whose parent path does not exist are skipped, and the remaining
-- fields are still checked.
function _M.check_tls_bool(fields, conf, plugin_name)
    for _, field in ipairs(fields) do
        local new_conf, new_field = get_root_conf(field, conf)

        if new_conf then
            local value = new_conf[new_field]
            if value ~= true and value ~= nil then
                log.warn("Keeping ", field, " disabled in ",
                         plugin_name, " configuration is a security risk")
            end
        end
    end
end


-- Export the module table.
return _M
